{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "faf25b25-ffdc-4237-8f04-2a4857ad545e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OrderedDict({'weight': tensor([[ 0.0954,  0.1446, -0.0513, -0.2403, -0.1528, -0.2150, -0.0533, -0.1819]]), 'bias': tensor([-0.1339])})\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# Seed the global RNG so that re-running the notebook from a fresh kernel\n",
    "# produces the same initial weights and the same random inputs.\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Two-layer MLP: 4 inputs -> 8 hidden units (ReLU) -> 1 output.\n",
    "net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))\n",
    "# Mini-batch of 2 random samples with 4 features each.\n",
    "X = torch.rand(size=(2, 4))\n",
    "net(X)\n",
    "# Parameters of the output layer (index 2 of the Sequential).\n",
    "print(net[2].state_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "93d57c7a-4973-4f46-8baa-501eafe0f44a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'torch.nn.parameter.Parameter'>\n",
      "Parameter containing:\n",
      "tensor([-0.1339], requires_grad=True)\n",
      "tensor([-0.1339])\n"
     ]
    }
   ],
   "source": [
    "# Show the output layer's bias three ways: its Python type, the Parameter\n",
    "# itself, and the tensor exposed through .data.\n",
    "out_bias = net[2].bias\n",
    "print(type(out_bias), out_bias, out_bias.data, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8566abb2-a77c-4355-b8ff-fc6096b4521f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# No backward pass has been run so far, so the gradient is still unset.\n",
    "# Use an identity check for None: `==` on a tensor is an element-wise operator.\n",
    "net[2].weight.grad is None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "97b36db5-3de6-45c2-9e21-8dcc4c29f665",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('weight', torch.Size([8, 4])) ('bias', torch.Size([8]))\n",
      "('0.weight', torch.Size([8, 4])) ('0.bias', torch.Size([8])) ('2.weight', torch.Size([1, 8])) ('2.bias', torch.Size([1]))\n"
     ]
    }
   ],
   "source": [
    "# List (name, shape) pairs for the first layer's parameters, then for the whole network.\n",
    "for module in (net[0], net):\n",
    "    print(*((name, param.shape) for name, param in module.named_parameters()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7448b049-6bc5-401f-bdbb-8aada21f3899",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.1339])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# state_dict() exposes every parameter under a dotted key such as '2.bias'.\n",
    "flat_params = net.state_dict()\n",
    "flat_params['2.bias'].data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "83e7a80d-33fd-479a-aedd-508ccbc1367e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): Sequential(\n",
      "    (block 0): Sequential(\n",
      "      (0): Linear(in_features=4, out_features=8, bias=True)\n",
      "      (1): ReLU()\n",
      "      (2): Linear(in_features=8, out_features=4, bias=True)\n",
      "      (3): ReLU()\n",
      "    )\n",
      "    (block 1): Sequential(\n",
      "      (0): Linear(in_features=4, out_features=8, bias=True)\n",
      "      (1): ReLU()\n",
      "      (2): Linear(in_features=8, out_features=4, bias=True)\n",
      "      (3): ReLU()\n",
      "    )\n",
      "    (block 2): Sequential(\n",
      "      (0): Linear(in_features=4, out_features=8, bias=True)\n",
      "      (1): ReLU()\n",
      "      (2): Linear(in_features=8, out_features=4, bias=True)\n",
      "      (3): ReLU()\n",
      "    )\n",
      "    (block 3): Sequential(\n",
      "      (0): Linear(in_features=4, out_features=8, bias=True)\n",
      "      (1): ReLU()\n",
      "      (2): Linear(in_features=8, out_features=4, bias=True)\n",
      "      (3): ReLU()\n",
      "    )\n",
      "  )\n",
      "  (1): Linear(in_features=4, out_features=1, bias=True)\n",
      ")\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([ 0.0439, -0.1959,  0.1394,  0.4716, -0.4470, -0.1951, -0.2275,  0.3053])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def block1():\n",
    "    \"\"\"Return an MLP block: Linear(4, 8) -> ReLU -> Linear(8, 4) -> ReLU.\"\"\"\n",
    "    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(),\n",
    "                         nn.Linear(8, 4), nn.ReLU())\n",
    "\n",
    "def block2(num_blocks=4):\n",
    "    \"\"\"Stack `num_blocks` fresh copies of block1() inside one Sequential.\n",
    "\n",
    "    Each child is registered under the name 'block {i}'. The default of 4\n",
    "    matches the depth that used to be hard-coded.\n",
    "    \"\"\"\n",
    "    # Local name deliberately differs from the notebook-level `net`.\n",
    "    stacked = nn.Sequential()\n",
    "    for i in range(num_blocks):\n",
    "        # Nest a newly built block under its own name.\n",
    "        stacked.add_module(f'block {i}', block1())\n",
    "    return stacked\n",
    "\n",
    "rgnet = nn.Sequential(block2(), nn.Linear(4, 1))\n",
    "rgnet(X)\n",
    "print(rgnet)\n",
    "# Bias of the first Linear layer inside the second nested block.\n",
    "rgnet[0][1][0].bias.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "44640d56-3dce-45f3-8b56-91c576618e00",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.1887],\n",
       "        [-0.1887]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Forward pass of the nested network on X; the result is shown as the cell output.\n",
    "rgnet(X)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ae71016b-9114-476a-8efe-06234c98048b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([-0.0145,  0.0039, -0.0058, -0.0088]), tensor(0.))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def init_normal(m):\n",
    "    \"\"\"Initialize a Linear layer with weights ~ N(0, 0.01^2) and a zero bias.\n",
    "\n",
    "    Meant to be passed to `Module.apply`; modules that are not Linear\n",
    "    layers are left untouched.\n",
    "    \"\"\"\n",
    "    # isinstance (rather than an exact type comparison) also covers subclasses.\n",
    "    if isinstance(m, nn.Linear):\n",
    "        nn.init.normal_(m.weight, mean=0, std=0.01)\n",
    "        # bias is None when the layer was built with bias=False.\n",
    "        if m.bias is not None:\n",
    "            nn.init.zeros_(m.bias)\n",
    "\n",
    "net.apply(init_normal)\n",
    "# First weight row and first bias entry of the first layer.\n",
    "net[0].weight.data[0], net[0].bias.data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "9aaba219-63e3-4291-9da7-4b1860dfc807",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.3073,  0.0355, -0.3025,  0.1403],\n",
       "        [ 0.6706, -0.5022, -0.1688, -0.4834],\n",
       "        [-0.1179, -0.5042, -0.0892,  0.0189],\n",
       "        [-0.3264,  0.4265,  0.6790, -0.3737],\n",
       "        [-0.4732,  0.2356,  0.3825,  0.2613],\n",
       "        [-0.0077,  0.2238,  0.6168, -0.2916],\n",
       "        [-0.5646, -0.0327,  0.5411, -0.4785],\n",
       "        [ 0.3378, -0.3351, -0.6499, -0.5748]])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Full weight matrix of the first layer.\n",
    "# NOTE(review): this cell's saved execution count (13) is higher than that of the\n",
    "# cells that follow it, so the stored output reflects a later kernel state than a\n",
    "# top-to-bottom run would produce here. Re-run the notebook in order to refresh it.\n",
    "net[0].weight.data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "af1f1ed7-e865-4fb4-8e2b-9f0e3f3daf76",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([1., 1., 1., 1.]), tensor(0.))"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def init_constant(m):\n",
    "    \"\"\"Initialize a Linear layer with every weight set to 1 and a zero bias.\n",
    "\n",
    "    Meant to be passed to `Module.apply`; modules that are not Linear\n",
    "    layers are left untouched.\n",
    "    \"\"\"\n",
    "    # isinstance (rather than an exact type comparison) also covers subclasses.\n",
    "    if isinstance(m, nn.Linear):\n",
    "        nn.init.constant_(m.weight, 1)\n",
    "        # bias is None when the layer was built with bias=False.\n",
    "        if m.bias is not None:\n",
    "            nn.init.zeros_(m.bias)\n",
    "\n",
    "net.apply(init_constant)\n",
    "# First weight row and first bias entry of the first layer.\n",
    "net[0].weight.data[0], net[0].bias.data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "64c96b14-f5e8-43e2-b32e-515de45b03d8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([-0.3073,  0.0355, -0.3025,  0.1403])\n",
      "tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])\n"
     ]
    }
   ],
   "source": [
    "def init_xavier(m):\n",
    "    \"\"\"Apply Xavier (Glorot) uniform initialization to a Linear layer's weight.\"\"\"\n",
    "    # isinstance (rather than an exact type comparison) also covers subclasses.\n",
    "    if isinstance(m, nn.Linear):\n",
    "        nn.init.xavier_uniform_(m.weight)\n",
    "\n",
    "def init_42(m):\n",
    "    \"\"\"Fill a Linear layer's weight with the constant 42.\"\"\"\n",
    "    if isinstance(m, nn.Linear):\n",
    "        nn.init.constant_(m.weight, 42)\n",
    "\n",
    "# Different initializers can be applied to individual layers.\n",
    "net[0].apply(init_xavier)\n",
    "net[2].apply(init_42)\n",
    "print(net[0].weight.data[0])\n",
    "print(net[2].weight.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "63821c4b-9a75-42cc-985f-0e855256ebd0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Init ('weight', torch.Size([8, 4])) ('bias', torch.Size([8]))\n",
      "Init ('weight', torch.Size([1, 8])) ('bias', torch.Size([1]))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[ 9.8862, -0.0000,  0.0000, -5.7009],\n",
       "        [-0.0000,  0.0000,  5.8780,  0.0000]], grad_fn=<SliceBackward0>)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def my_init(m):\n",
    "    \"\"\"Custom initializer for Linear layers.\n",
    "\n",
    "    Prints the layer's parameter names and shapes, draws each weight from\n",
    "    U(-10, 10), then zeroes every entry whose magnitude is below 5.\n",
    "    The bias is not modified.\n",
    "    \"\"\"\n",
    "    # isinstance (rather than an exact type comparison) also covers subclasses.\n",
    "    if isinstance(m, nn.Linear):\n",
    "        print(\"Init\", *[(name, param.shape)\n",
    "                        for name, param in m.named_parameters()])\n",
    "        nn.init.uniform_(m.weight, -10, 10)\n",
    "        # Apply the mask in place without recording it in autograd,\n",
    "        # instead of writing through `.data`.\n",
    "        with torch.no_grad():\n",
    "            m.weight.mul_(m.weight.abs() >= 5)\n",
    "\n",
    "net.apply(my_init)\n",
    "# First two weight rows of the first layer.\n",
    "net[0].weight[:2]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "dd1f6a40-9680-45f1-9c44-ffcf6e341bc3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([42.0000,  1.0000,  1.0000, -4.7009])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Parameters can be overwritten directly through .data: shift every weight of\n",
    "# the first layer by 1, then set the top-left entry to 42.\n",
    "layer0_w = net[0].weight.data\n",
    "layer0_w.add_(1)\n",
    "layer0_w[0, 0] = 42\n",
    "layer0_w[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "9e487617-f2c1-427e-a141-125752525e20",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([True, True, True, True, True, True, True, True])\n",
      "tensor([True, True, True, True, True, True, True, True])\n"
     ]
    }
   ],
   "source": [
    "# The shared layer needs a name so that its parameters can be referenced.\n",
    "shared = nn.Linear(8, 8)\n",
    "net = nn.Sequential(\n",
    "    nn.Linear(4, 8), nn.ReLU(),\n",
    "    shared, nn.ReLU(),\n",
    "    shared, nn.ReLU(),\n",
    "    nn.Linear(8, 1),\n",
    ")\n",
    "net(X)\n",
    "\n",
    "w_first, w_second = net[2].weight.data, net[4].weight.data\n",
    "# Check that the parameters at both positions are equal.\n",
    "print(w_first[0] == w_second[0])\n",
    "w_first[0, 0] = 100\n",
    "# Still equal after writing through one of them: they are the same object,\n",
    "# not merely two tensors that happen to hold equal values.\n",
    "print(w_first[0] == w_second[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf6dbdf4-20a1-436b-8e37-0f6f15b2161f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
